{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"background-color: #C8E6C9; padding: 10px; color: #1b7678\">\n",
    "<b>Pre-requisites</b>: Intermediate knowledge of Deep Learning and basic knowledge of Tabular Problems like Regression and Classification. Also go through the <i>Approaching Any Tabular Problem with PyTorch Tabular</i> tutorial.  <br></br>\n",
    "<b>Level</b>: Intermediate\n",
    "</div>\n",
    "\n",
    "In the _Approaching Any Tabular Problem_\n",
    " with PyTorch Tabular, we saw how to start using PyTorch Tabular with it's intelligent defaults. In this tutorial, we will see how to leverage sightly advanced features of PyTorch Tabular to have more flexibility and typically better results. In this tutorial, we assume you already know how to use basic features of PyTorch Tabular. If you are not familiar with PyTorch Tabular, please go through the _Approaching Any Tabular Problem_ with PyTorch Tabular tutorial first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rich.pretty import pprint\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "First of all, let's create a synthetic data which is a mix of numerical and categorical features and have multiple targets for regression. It means that there are multiple columns which we need to predict with the same set of features. Most classical machine learning models (like the ones in scikit-learn) only handle single target problems. We will have to train different models for each target. While this is perfectly fine, it is not the most efficient way to do it. First of all, we will have to train multiple models which will take more time. Secondly, if the two targets have some relationship between them, we are not leveraging that information. For example, if we are predicting the price of a house and the area of the house, we know that the two targets are related. If we train two different models, we are not leveraging this information.\n",
    "\n",
    "PyTorch Tabular can handle **multi-target problems** out of the box (only for Regression currently). We just need to pass the list of target columns to the `target` parameter of the `DataConfig` class. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "Collapsed": "false",
    "id": "RasjIOvb8jEB"
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from pytorch_tabular.utils import make_mixed_dataset, print_metrics\n",
    "data, cat_col_names, num_col_names = make_mixed_dataset(\n",
    "    task=\"regression\", n_samples=100000, n_features=20, n_categories=4, n_targets=2, random_state=42\n",
    ")\n",
    "target_cols = [\"target_0\", \"target_1\"]\n",
    "train, test = train_test_split(data, random_state=42)\n",
    "train, val = train_test_split(train, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false",
    "id": "CAz_0U8o8jEO"
   },
   "source": [
    "Let's import the required classes from `PyTorch Tabular`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "Collapsed": "false",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "I4pWHyuH8jEO",
    "outputId": "20d05c98-cf34-41da-82bf-ced9a297ea26"
   },
   "outputs": [],
   "source": [
    "from pytorch_tabular import TabularModel\n",
    "from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig\n",
    "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig\n",
    "from pytorch_tabular.models.common.heads import LinearHeadConfig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false",
    "id": "jSzCJ6Xq8jEP"
   },
   "source": [
    "We already know the basic steps to define a model in PyTorch Tabular. We need to define a few configs and initialize the `TabularModel`. Let's do that.\n",
    "\n",
    "But this time, let's look at some of the more advanced features of PyTorch Tabular.\n",
    "\n",
    "### 1. `DataConfig`\n",
    "\n",
    "We know we need to define the target, continuous and categorical columns in the `DataConfig`. But there are a few more parameters which we can use to customize the data processing pipeline. Let's look at a few of them.\n",
    "\n",
    "- `normalize_continuous`: For better optimization, DL models prefer normalized continuous features. By default, PyTorch Tabular normalizes the continuous features. But if you want to use a custom normalization, you can do the normalization outside of PyTorch Tabular and pass `normalize_continuous=False` to the `DataConfig`. PyTorch Tabular will not normalize the continuous features and will use the values as is.\n",
    "- `continuous_feature_transform`: Sometimes, we want to transform the continuous features before feeding them to the model. For example, we might want to take the log of a feature or take the square root of a feature, etc. PyTorch Tabular has a few such transformations built-in. You can pass the name of the transformation to the `continuous_feature_transform` parameter of the `DataConfig`. The allowable inputs are: `['quantile_normal', 'yeo-johnson', 'quantile_uniform', 'box-cox']`. This internally uses a few scikit-learn transformers to do the transformation. You can read about these [here](https://scikit-learn.org/stable/modules/preprocessing.html#non-linear-transformation)\n",
    "- `num_workers` and `pin_memory` are two parameters which are used to speed up the data loading process. If you are using a GPU, you can set `num_workers` to a number greater than 0 (Only for Linux). This will use multiple CPU cores to load the data in parallel. `pin_memory` is a parameter which is used to speed up the data transfer from CPU to GPU. If you are using a GPU, you can set `pin_memory=True` to speed up the data transfer. You can read more about these [here](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)\n",
    "\n",
    "For the entire list of parameters, please refer to the API Reference in the docs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_config = DataConfig(\n",
    "    target=target_cols,\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    "    num_workers=10,\n",
    "    normalize_continuous_features=True,\n",
    "    continuous_feature_transform=\"quantile_normal\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. `TrainingConfig`\n",
    "\n",
    "Training a Deep Learning model can get arbritarily complex. PyTorch Tabular, by inheriting PyTorch Lightning, offloads the whole workload onto the underlying PyTorch Lightning Framework. In the basic tutorial, we just scratched the surface of what PyTorch Lightning can do. In this tutorial, we will see how to leverage some of the more advanced features of PyTorch Lightning as well as a few convenience features of PyTorch Tabular.\n",
    "\n",
    "We already know that we can pass the `max_epochs`, `batch_size` to `TrainerConfig`. Let's look at a few more parameters.\n",
    "\n",
    "**`accelerator`**  \n",
    "\n",
    "PyTorch Lightning supports training on multiple GPUs and TPUs. You can pass the accelerator type to the `accelerator` parameter of the `TrainerConfig`. The allowable inputs are: `['cpu','gpu','tpu','ipu','auto']`. `cpu` let's you train the model on CPUs. `gpu` let's you train the model on GPUs. `tpu` let's you train the model on TPUs. `ipu` let's you train the model on IPUs. `auto` let's PyTorch Lightning choose the best accelerator for you. You can read more about these [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html)\n",
    "\n",
    "**`devices` and `devices_list`**  \n",
    "\n",
    "`devices` let's you choose number of devices (CPU Cores, GPUs etc.) to train the model on. -1 means training on all available devices. `devices_list` let's you choose the specific devices to train the model on. For example, if you want to train the model on GPU 0 and 1, you can pass `devices_list=[0,1]`.\n",
    "\n",
    "**`min_epochs` and `max_time`**   \n",
    "\n",
    "These are also parameters that help you control the training, apart from `max_epochs`. `min_epochs` let's you specify the minimum number of epochs to train the model for (typically useful when you are using early stopping). `max_time` let's you specify the maximum time to train the model for.\n",
    "\n",
    "**`Early Stopping Parameters`**     \n",
    "\n",
    "PyTorch Lightning supports early stopping out of the box. Early stopping is a technique to stop the training process if the model is not improving by monitoring a loss/metric on the validation set.\n",
    "You can pass the following parameters to the `TrainerConfig` to use early stopping: \n",
    ">        `early_stopping`: The loss/metric to monitor for early stopping. If set to None, early stopping will not be used.   \n",
    ">        `early_stopping_min_delta`: The minimum change in the loss/metric that qualifies as an improvement for early stopping.   \n",
    ">        `early_stopping_mode`: The direction in which the loss/metric should be optimized. Choices are: ['max', 'min'].   \n",
        `early_stopping_patience`">
    ">        `early_stopping_patience`: The number of epochs to wait without any improvement in the loss/metric before stopping.   \n",
    ">        `early_stopping_kwargs`: Additional keyword arguments for the early stopping callback. Refer to the PyTorch Lightning EarlyStopping callback documentation for more details.   \n",
    ">        `load_best`: If True, loads the best model weights at the end of training. Defaults to True.\n",
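        `load_best`">
    "\n",
    "To make the interplay of `early_stopping_mode`, `early_stopping_min_delta` and `early_stopping_patience` concrete, here is a minimal pure-Python sketch of a patience-based stopping rule. It is a simplified illustration, not PyTorch Lightning's actual callback; the function name `stopping_epoch` and the loss values are made up for this example.\n",
    "\n",
    "```python\n",
    "def stopping_epoch(history, mode='min', min_delta=0.001, patience=3):\n",
    "    # Hypothetical helper: returns the epoch index at which training would stop, or None\n",
    "    best = None\n",
    "    bad_epochs = 0\n",
    "    for epoch, value in enumerate(history):\n",
    "        if best is None:\n",
    "            improved = True\n",
    "        elif mode == 'min':\n",
    "            # A lower value only counts if it beats the best by more than min_delta\n",
    "            improved = value < best - min_delta\n",
    "        else:\n",
    "            improved = value > best + min_delta\n",
    "        if improved:\n",
    "            best = value\n",
    "            bad_epochs = 0\n",
    "        else:\n",
    "            bad_epochs += 1\n",
    "            if bad_epochs >= patience:\n",
    "                return epoch\n",
    "    return None\n",
    "\n",
    "\n",
    "# Made-up validation losses: meaningful improvement ends after the third epoch\n",
    "valid_loss = [0.90, 0.70, 0.65, 0.6495, 0.6499, 0.6497, 0.64]\n",
    "stop_at = stopping_epoch(valid_loss, mode='min', min_delta=0.001, patience=3)\n",
    "print(stop_at)\n",
    "```\n",
    "\n",
    "In this sketch, the small fluctuations after the third epoch are below `min_delta`, so they do not reset the patience counter and the later drop to `0.64` is never reached.\n",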
    "\n",
    "**`Checkpoint Saving Parameters`**\n",
    "\n",
    "PyTorch Lightning supports saving the model checkpoints automatically. Checkpoint saving is a technique to save the model weights at regular intervals during training. This is useful in case the training process is interrupted due to some reason, or if we want to go back and use a weight from a previous epoch. Typically useful when using early stopping, so that we can roll back and use the best model weights. You can pass the following parameters to the `TrainerConfig` to save the model checkpoints:\n",
      `checkpoints`: str">
    ">      `checkpoints`: str: The loss/metric that needs to be monitored for checkpoints. If None, no checkpoints will be saved    \n",
    ">      `checkpoints_path`: str: The path where the saved models will be stored. Defaults to `saved_models`    \n",
    ">      `checkpoints_mode`: str: The direction in which the loss/metric should be optimized. Choices are `max` and `min`. Defaults to `min`    \n",
    ">      `checkpoints_save_top_k`: int: The number of best models to save. If you want to save more than one best model, you can set this parameter to >1. Defaults to `1`    \n",
    "\n",
    "<div style=\"background-color: #fce3b3; padding: 10px; color: #e76013\">\n",
    "<b>Note</b>: Make sure the name of the metric/loss you want to track exactly matches the ones in the logs. Recommended way is to run a model and check the results by evaluating the model. From the resulting dictionary, you can pick up a key to track during training.</br>\n",
    "</div>\n",
    "\n",
    "**`Learning Rate Finder`**\n",
    "\n",
    "First proposed in this paper [Cyclical Learning Rates for Training Neural Networks](https://arxiv.org/abs/1506.01186) and the subsequently popularized by fast.ai, is a technique to reach the neighbourhood of optimum learning rate without costly search. PyTorch Tabular let's you find the optimal learning rate(using the method proposed in the paper) and automatically use that for training the network. All this can be turned on with a simple flag `auto_lr_find`\n",
    "\n",
    "**`Controlling the Gradients/Optimization`**\n",
    "\n",
    "While training, there can be situations where you need to have a heavier control on the gradient optimization process. For eg. if the gradients are exploding, you might want to clip gradient values before each update. `gradient_clip_val` let's you do that.\n",
    "\n",
    "Sometimes, you might want to accumulate gradients across multiple batches before you do a backward propoagation(may be because a larger batch size does not fit in your GPU). PyTorch Tabular let's you do this with `accumulate_grad_batches`\n",
    "\n",
    "**`Debugging Analysis`**\n",
    "\n",
    "Many times, you will need to debug a model and see why it is not performing as it is supposed to. Or even, while developing new models, you will need to debug the model a lot. PyTorch Lightning has a few features for this usecase, which Pytorch Tabular has adopted.\n",
    "\n",
    "To find out performance bottle necks, we can use:\n",
    "\n",
    "- `profiler`: Optional\\[str\\]: To profile individual steps during training and assist in identifying bottlenecks. Choices are: `None` `simple` `advanced`. Defaults to `None`\n",
    "\n",
    "To check if the whole setup runs without errors, we can use:\n",
    "\n",
    "- `fast_dev_run`: Optional\\[str\\]: Quick Debug Run of Val. Defaults to `False`\n",
    "\n",
    "If the model is not learning properly:\n",
    "\n",
    "- `overfit_batches`: float: Uses this much data of the training set. If nonzero, will use the same training set for validation and testing. If the training dataloaders have shuffle=True, Lightning will automatically disable it. Useful for quickly debugging or trying to overfit on purpose. Defaults to `0`\n",
    "\n",
    "- `track_grad_norm`: bool: This is only used if experiment tracking is setup. Track and Log Gradient Norms in the logger. -1 by default means no tracking. 1 for the L1 norm, 2 for L2 norm, etc. Defaults to `False`. If the gradient norm falls to zero quickly, then we have a problem.\n",
    "\n",
    "For the entire list of parameters, please refer to the API Reference in the docs.\n",
    "\n",
    "**`YAML Config`**\n",
    "\n",
    "PyTorch Tabular let's you define the any config either as a Config Class or as a YAML file. YAML files are a great way to store configs. It is human readable and easy to edit. PyTorch Tabular let's you define the config in a YAML file and pass the path of the YAML file to the respective config parameters of the `TabularModel`. \n",
    "\n",
    "We have defined a YAML file for `TrainerConfig` with the below contents:\n",
    "```yaml\n",
    "batch_size: 1024\n",
    "fast_dev_run: false\n",
    "max_epochs: 20\n",
    "min_epochs: 1\n",
    "accelerator: 'auto'\n",
    "devices: -1\n",
    "accumulate_grad_batches: 1\n",
    "auto_lr_find: true\n",
    "check_val_every_n_epoch: 1\n",
    "gradient_clip_val: 0.0\n",
    "overfit_batches: 0.0\n",
    "profiler: null\n",
    "early_stopping: null\n",
    "early_stopping_min_delta: 0.001\n",
    "early_stopping_mode: min\n",
    "early_stopping_patience: 3\n",
    "checkpoints: valid_loss\n",
    "checkpoints_path: saved_models\n",
    "checkpoints_mode: min\n",
    "checkpoints_save_top_k: 1\n",
    "load_best: true\n",
    "track_grad_norm: -1\n",
    "\n",
    "```\n",
    "\n",
    "Let's use that instead of defining the `TrainerConfig` as a class.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. `OptimizerConfig`\n",
    "\n",
    "The Optimizer is at the heart of the Gradient Descent process and is a key component that we need to train a good model, and `OptimizerConfig` let's you customize the optimizer and learning rate scheduler to your needs. Pytorch Tabular uses `Adam` optimizer with a learning rate of `1e-3` by default. This is mainly because of a rule of thumb which provides a good starting point.\n",
    "\n",
    "Let's look at a few parameters which we can use to customize the optimizer.\n",
    "\n",
    "**`optimizer`**\n",
    "\n",
    "PyTorch Tabular let's you choose any optimizer from the `torch.optim` package by passing in the name of the optimizer as a string to the `optimizer` parameter of the `OptimizerConfig`. This includes optimizers like `Adam`, `SGD`, `RMSProp`, `AdamW` etc. You can read more about these [here](https://pytorch.org/docs/stable/optim.html). In addition to these, PyTorch Tabular also supports any valid PyTorch Optimizer. If it's an optimizer which can be accessed from a namespace (like a library you installed), we can pass the fully qualified name of the optimizer to the `optimizer` parameter. For example, if you have installed the `torch_optimizer` library, and want to use `QHAdam` from there, you can pass `torch_optimizer.QHAdam` to the `optimizer` parameter. If it's an optimizer which is not accessible from a namespace, you cannot pass it in the OptimizerConfig, but will be able to use it during the `fit` which we will see later.\n",
    "\n",
    "**`optimizer_params`**\n",
    "\n",
    "PyTorch Tabular let's you pass any valid optimizer parameters (except learning rate) to the `optimizer_params` parameter of the `OptimizerConfig`. For example, if you want to use a weight decay of `1e-2`, you can pass `optimizer_params={'weight_decay':1e-2}` to the `OptimizerConfig`. You need to refer to the documentation of the optimizer you are using to find out the valid parameters. \n",
    "\n",
    "**`lr_scheduler`**\n",
    "\n",
    "Learning Schedulers are a way to control the learning rate during training. Sometimes, it is beneficial to start off with a slightly higher learning rate and reduce it as we progress in training. Sometimes, it helps if we reduce learning rate when we hit a plateau while learning. PyTorch Tabular let's you choose any learning rate scheduler from the `torch.optim.lr_scheduler` package by passing in the name of the scheduler as a string to the `lr_scheduler` parameter of the `OptimizerConfig`. This includes schedulers like `StepLR`, `ReduceLROnPlateau`, `CosineAnnealingLR` etc. You can read more about these [here](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate). \n",
    "\n",
    "**`lr_scheduler_params`**\n",
    "\n",
    "PyTorch Tabular let's you pass any valid learning rate scheduler parameters to the `lr_scheduler_params` parameter of the `OptimizerConfig`. For example, if you want to use a step size of `10` for `StepLR`, you can pass `lr_scheduler_params={'step_size':10}` to the `OptimizerConfig`. You need to refer to the documentation of the scheduler you are using to find out the valid parameters.\n",
    "\n",
    "**`lr_scheduler_monitor_metric`**\n",
    "\n",
    "This is a parameter which is used only if you are using `ReduceLROnPlateau` as the learning rate scheduler. This is the metric which will be monitored to reduce the learning rate. This should be a valid loss or metric defined in the model.\n",
    "\n",
    "\n",
    "Here, let's use a [CosineAnnealingLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html) with a warmup of 10 epochs as the Learning Rate Scheduler and an optimizer from a third-party library `torch_optimizer` (You will need to install that library)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer_config = OptimizerConfig(\n",
    "    optimizer=\"torch_optimizer.QHAdam\",\n",
    "    optimizer_params={\"nus\": (0.7, 1.0), \"betas\": (0.95, 0.998)},\n",
    "    lr_scheduler=\"CosineAnnealingWarmRestarts\",\n",
    "    lr_scheduler_params={\"T_0\": 10, \"T_mult\": 1, \"eta_min\": 1e-5},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. `ModelConfig`\n",
    "\n",
    "`ModelConfig` is how you decide the kind of model and the model parameters to be used in the model. PyTorch Tabular has implemented a few SOTA models for tabular data. Internally in PyTorch Tabular, a model has three components:\n",
    "\n",
    "1. Embedding Layer - This is the part of the model which processes the categorical and continuous features into a single tensor.\n",
    "1. Backbone - This is the real architecture of the model. It is the part of the model which takes the output of the embedding layer and does representation learning on it. The output is again a single tensor, which is the learned features from representation learning.\n",
    "1. Head - This is the part of the model which takes the output of the backbone and does the final classification/regression. The output of the head is the final prediction.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">[</span>\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'AutoIntConfig'</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'CategoryEmbeddingModelConfig'</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'DANetConfig'</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'FTTransformerConfig'</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'GANDALFConfig'</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'GatedAdditiveTreeEnsembleConfig'</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'MDNConfig'</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'NodeConfig'</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'TabNetModelConfig'</span>,\n",
       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #008000; text-decoration-color: #008000\">'TabTransformerConfig'</span>\n",
       "<span style=\"font-weight: bold\">]</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m[\u001b[0m\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'AutoIntConfig'\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'CategoryEmbeddingModelConfig'\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'DANetConfig'\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'FTTransformerConfig'\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'GANDALFConfig'\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'GatedAdditiveTreeEnsembleConfig'\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'MDNConfig'\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'NodeConfig'\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'TabNetModelConfig'\u001b[0m,\n",
       "\u001b[2;32m│   \u001b[0m\u001b[32m'TabTransformerConfig'\u001b[0m\n",
       "\u001b[1m]\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from pytorch_tabular import available_models\n",
    "pprint(available_models())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can choose any of these models by importing the corresponding class from `pytorch_tabular.models`, set the parameters and pass it to the `model_config` parameter of the `TabularModel`. All these config classes have been inherited from a common `ModelConfig` with a few standard parameters and any model specific parameters are added to the respective config class. And because of the inheritance, we have access to all the parameters of the `ModelConfig` in all the model config classes.\n",
    "\n",
    "Let's first look at some common parameters of the `ModelConfig`:\n",
    "\n",
    "- `task`: str: This defines whether we are running the model for a `regression`, `classification` task, or as a `backbone` model. `backbone` task is used in Self-Supervised models and in Mixed Density Models.\n",
    "\n",
    "**Head Configuration**\n",
    "\n",
    "- `head`: Optional\\[str\\]: The head to be used for the model. Should be one of the heads defined in `pytorch_tabular.models.common.heads`. Defaults `LinearHead`. Below cell shows the list of available heads."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">[</span><span style=\"color: #008000; text-decoration-color: #008000\">'LinearHead'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'MixtureDensityHead'</span><span style=\"font-weight: bold\">]</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m[\u001b[0m\u001b[32m'LinearHead'\u001b[0m, \u001b[32m'MixtureDensityHead'\u001b[0m\u001b[1m]\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pytorch_tabular as pt\n",
    "pprint([h for h in dir(pt.models.common.heads) if (not h.startswith(\"_\") and \"Head\" in h and \"Config\" not in h)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- `head_config`: Optional\\[Dict\\]: The config as a dict which defines the head. If left empty, will be initialized as default linear head. Although the input is a dictionary, it is recommended to use the `<Specific>HeadConfig` class for the respective head to make sure you re only using allowable parameters. For example, if you are using `LinearHead`, you can use `LinearHeadConfig` to define the head config. Below cell shows the list of available head configs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">[</span><span style=\"color: #008000; text-decoration-color: #008000\">'LinearHeadConfig'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'MixtureDensityHeadConfig'</span><span style=\"font-weight: bold\">]</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m[\u001b[0m\u001b[32m'LinearHeadConfig'\u001b[0m, \u001b[32m'MixtureDensityHeadConfig'\u001b[0m\u001b[1m]\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pprint([h for h in dir(pt.models.common.heads) if (not h.startswith(\"_\") and \"Head\" in h and \"Config\" in h)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Embedding Configuration**\n",
    "\n",
    "- `embedding_dims`: Optional\\[List\\]: The dimensions of the embedding for each categorical column as a list of tuples (cardinality, embedding_dim). If left empty, will infer using the cardinality of the categorical column using the rule `min(50, (x + 1) // 2)`\n",
    "- `embedding_dropout`: float: Dropout to be applied to the Categorical Embedding. Defaults to 0.0\n",
    "- `batch_norm_continuous_input`: bool: If True, we will normalize the continuous layer by passing it through a BatchNorm layer before combining with the categorical embeddings. Defaults to True\n",
    "\n",
    "**Other Configuration**\n",
    "\n",
    "- `learning_rate`: float: The learning rate of the model. Defaults to 1e-3.\n",
    "\n",
    "- `loss`: Optional\\[str\\]: The loss function to be applied. By Default it is MSELoss for regression and CrossEntropyLoss for classification. For most cases, these work well. But if you want to use any other loss function from PyTorch, you can pass it here. For example, if you want to use `BCEWithLogitsLoss`, you can pass `loss='BCEWithLogitsLoss'`. You can read more about the losses in PyTorch [here](https://pytorch.org/docs/stable/nn.html#loss-functions).  We can also use custom loss functions. We will see how to do that later.\n",
    "\n",
    "<div style=\"background-color: #fce3b3; padding: 10px; color: #e76013\">\n",
    "<b>Note</b>: Choosing the Loss Function should not be treated like a hyperparameter which you blindly apply, but a well thought out decision.</br>\n",
    "</div>\n",
    "\n",
    "- `metrics`: Optional\\[List\\[str\\]\\]: The list of metrics you need to track during training. The metrics should be one of the **functional** metrics implemented in `torchmetrics`. You can find the entire list [here](https://lightning.ai/docs/torchmetrics/stable/all-metrics.html). By default, it is `accuracy` if classification and `mean_squared_error` for regression. We can also use custom metrics. We will see how to do that later.\n",
    "\n",
    "- `metrics_prob_input`: Optional\\[List\\[bool\\]\\]: Is a mandatory parameter for classification metrics defined in the config. This defines whether the input to the metric function is the probability or the class. Length should be same as the number of metrics. Defaults to None.\n",
    "\n",
    "- `metrics_params`: Optional\\[List\\]: The parameters to be passed to the metrics function. Some functions like the `f1_score` need additional parameters like `task` to be properly defined. This also let's you choose how to average the metric in multi-class classification. \n",
    "\n",
    "- `target_range`: Optional\\[List\\]: For classification problems, the targets are always 0 or 1, once we one-hot the class labels. But for regression, it's a real valued value between (-inf, inf), theoretically. More practically, it usually is between known bounds. Sometimes, it is an extra burden on the model to learn this bounds and `target_range` is a way to take that burden off the model. This technique was popularized by Jeremy Howard in fast.ai and is quite effective in practice. If we know that the output value of a regression should be between a `min` and `max` value, we can provide those values as a tuple to `target_range`. But a caveat is that there is an assumption that the distribution of the target is normal. If the distribution is not normal, it might not work as expected. In case of multiple targets, we set the `target_range` to be a list of tuples, each entry in the list corresponds to the respective entry in the `target` parameter. For classification problems, this parameter is ignored.\n",
    "\n",
    "```python\n",
    "target_range = [(train[target].min() * 0.8, train[target].max() * 1.2)]\n",
    "```\n",
    "\n",
    "- `virtual_batch_size`: Optional\\[int\\]: BatchNorm is a very useful technique (a necessary evil) to normalize the activations of the network. It typically leads to faster convergence, and stable training regimes. But when training with large batch sizes, BatchNorm can lead to \"overfitting\" (not in the traditional sense). One way to overcome this is to use GhostBatchNorm, where we split the batch into virtual batches and apply BatchNorm on each virtual batch. By setting `virtual_batch_size` to a number greater than 1, PyTorch Tabular will automatically convert all BatchNorms to GhostBatchNorms with the specified virtual batch size. \n",
    "\n",
    "- `seed`: int: The seed for reproducibility. Defaults to 42\n",
    "\n",
    "Now, each model we choose will have it's own set of parameters. The API Reference in the docs has the list of all the models and their respective parameters. Here let's use a simple MLP with categorical embeddings. This is called `CategoryEmbeddingModelConfig` in PyTorch Tabular.\n",
    "\n",
    "The key parameters we are going to use are:\n",
    "\n",
    "- `layers`: str: Hyphen-separated number of layers and units in the classification head. Defaults to `\"128-64-32\"`\n",
    "- `activation`: str: The activation type in the classification head. The default [activations in PyTorch](https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) like ReLU, TanH, LeakyReLU, etc. Defaults to `ReLU`\n",
    "- `initialization`: str: Initialization scheme for the linear layers. Choices are: `kaiming` `xavier` `random`. Defaults to `kaiming`\n",
    "- `use_batch_norm`: bool: Flag to include a BatchNorm layer after each Linear Layer+DropOut. Defaults to `False`\n",
    "- `dropout`: float: The probability of the element to be zeroed. This applies to all the linear layers. Defaults to `0.0`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "head_config = LinearHeadConfig(\n",
    "    layers=\"\",  # No additional layer in head, just a mapping layer to output_dim\n",
    "    dropout=0.1,\n",
    "    initialization=\"kaiming\",\n",
    ").__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)\n",
    "\n",
    "model_config = CategoryEmbeddingModelConfig(\n",
    "    task=\"regression\",\n",
    "    layers=\"64-32-16\",\n",
    "    activation=\"LeakyReLU\",\n",
    "    dropout=0.1,\n",
    "    initialization=\"kaiming\",\n",
    "    head=\"LinearHead\",  # Linear Head\n",
    "    head_config=head_config,  # Linear Head Config\n",
    "    learning_rate=1e-3,\n",
    "    target_range=[(float(train[col].min()),float(train[col].max())) for col in target_cols]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. `TabularModel`\n",
    "\n",
    "After defining all the configs, we need to put it all together and this is where `TabularModel` comes in. `TabularModel` is the core work horse, which orchestrates and sets everything up.\n",
    "\n",
    "`TabularModel` parses the configs and:\n",
    "\n",
    "1. initializes the model\n",
    "1. sets up the experiment tracking framework (if defined)\n",
    "1. initializes and sets up the `TabularDatamodule` which handles all the data transformations and preparation of the DataLoaders\n",
    "1. sets up the callbacks and the Pytorch Lightning Trainer\n",
    "1. enables you to train, save, load, predict, among other things\n",
    "\n",
    "Now that we have defined all the configs, let's initialize the `TabularModel` with the configs. We can pass the configs either as a class or as a YAML file. Let's use the YAML file for `TrainerConfig` and the class for the rest of the configs.\n",
    "\n",
    "Apart from the configs, we can also pass the following parameters to the `TabularModel`:\n",
    "\n",
    "- `verbose`: bool: If True, will print different messages during training indicating the progress. Defaults to `True`\n",
    "- `suppress_lightning_logger`: PyTorch Lightning prints out a lot of logs while training and this parameter let's you suppress those logs. Or to be more specific, it sets the logging level of the PyTorch Lightning logger to `ERROR`. Defaults to `False` as Pytorch Lightning logs are very useful for debugging. Only turn them off if you are sure you don't need them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">11</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">14:43:34</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">218</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">140</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m11\u001b[0m \u001b[1;92m14:43:34\u001b[0m,\u001b[1;36m218\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from pytorch_tabular import TabularModel\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=\"trainer_config.yml\",\n",
    "    verbose=True,\n",
    "    suppress_lightning_logger=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the Model\n",
    "\n",
    "In PyTorch Tabular, there are two ways of training the model - A High Level API and a Low Level API. The High Level API is a wrapper around the Low Level API and is the recommended way to train the model. But the Low Level API gives you more control over the training process and is useful if you want to do some custom training. Let's look at both of them.\n",
    "\n",
    "### 1. High Level API\n",
    "\n",
    "The High Level API is a single line of code which does everything for you. You just need to call the `fit` method of the `TabularModel` and it will take care of everything else. But we have already seen how to fit the model in the basic tutorial. So, let's look at a few more parameters of the `fit` method.\n",
    "\n",
    "- `loss`: This is where you can use a custom loss function. You can pass any valid PyTorch loss function to the `loss` parameter of the `fit` method. \n",
    "\n",
    "- `metrics`: This is where you can use a custom metric function. The parameter accepts a list of `Callable`s with the signature: `metric_fn(y_hat, y)`, where `y_hat` and `y` are tensors. `y_hat` is of shape `(batch_size, num_classes)` for classification and `(batch_size, 1)` for regression. `y` is of shape `(batch_size, 1)` for classification and `(batch_size, num_targets)` for regression. \n",
    "\n",
    "- `metrics_prob_inputs`: This is a mandatory parameter if you are using the `metrics` parameter. This is a list of boolean values which defines whether the input to the metric function is the probability or the class. Length should be same as the number of metrics. Defaults to None.\n",
    "\n",
    "- `optimizer` and `optimizer_params`: This is where you can use a custom optimizer. You can pass any valid PyTorch optimizer to the `optimizer` parameter of the `fit` method. You can also pass any valid optimizer parameters to the `optimizer_params` parameter of the `fit` method.\n",
    "\n",
    "- `train_samplers`: Sometimes, we would want to enforce some custom behaviour on the batch sampling. This parameter accepts any inherited class of `torch.utils.data.Sampler`. For example, if you want to use `WeightedRandomSampler`, you can pass `train_samplers=WeightedRandomSampler(...)`. You can read more about the samplers [here](https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler).\n",
    "\n",
    "- `target_transform`: This parameter is a Tuple (of size 2) of Callables and let's you use any custom transformation on the target. This is useful if you want to do some custom transformation on the target before passing it to the loss function. For example, if you want to take the log of the target, you can pass `target_transform=[np.log, np.exp]`. The first function will be applied to the target before passing it to the loss function and the second function will be applied to the output of the model.\n",
    "\n",
    "- `callbacks`: PyTorch Lightning supports a lot of callbacks out of the box. You can read more about them [here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/generated/pytorch_lightning.callbacks.Callback.html#pytorch_lightning.callbacks.Callback). PyTorch Lightning also supports custom callbacks. These callbacks are directly added to the Lightning Trainer.\n",
    "\n",
    "- `cache_data` - By default, PyTorch Tabular saves the data in the TabularDataModule. This is useful if you want to train the model multiple times without having to load the data again and again. But if you are running out of memory, you can choose to save the files to a path and load them from there. This parameter accepts a string which is the path where the data will be saved. Default value is `memory`.\n",
    "\n",
    "There are some more parameters, but the full list of options you can read in the API reference.\n",
    "\n",
    "Let's leverage a few of these custom options to train the model.\n",
    "We are using a\n",
    "- Dummy Target Transformation\n",
    "- Custom Loss Function, which is nothing but the MSE loss. This overrides the loss function defined in the `ModelConfig`\n",
    "- Custom Optimizer, which is just `Lamb` from `torch_optimizer` (Just to demonstrate how to use a custom optimizer). This overrides the optimizer defined in the `OptimizerConfig`\n",
    "- Custom Metric Function, which is quite meaningless, but just to show how to use it. This overrides the metric defined in the `ModelConfig`\n",
    "- Custom Callback, which just prints some message during different stages of training\n",
    "\n",
    "<div style=\"background-color: #fce3b3; padding: 10px; color: #e76013\">\n",
    "<b>Note</b> </br> PyTorch Tabular passes the raw output from the models to the loss function. In classification problems, the raw output is the logits. It is the responsibility of the loss function to apply the right activations. And if you don't understand what that means, leave the loss functions at default values or use pre-implemented loss functions.\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">11</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">14:43:34</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">615</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">524</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the DataLoaders                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m11\u001b[0m \u001b[1;92m14:43:34\u001b[0m,\u001b[1;36m615\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m524\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">11</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">14:43:34</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">631</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_datamodul<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e:499</span><span style=\"font-weight: bold\">}</span> - INFO - Setting up the datamodule for          \n",
       "regression task                                                                                                    \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m11\u001b[0m \u001b[1;92m14:43:34\u001b[0m,\u001b[1;36m631\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:499\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for          \n",
       "regression task                                                                                                    \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">11</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">14:43:34</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">893</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">574</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Model: CategoryEmbeddingModel \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m11\u001b[0m \u001b[1;92m14:43:34\u001b[0m,\u001b[1;36m893\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m574\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: CategoryEmbeddingModel \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">11</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">14:43:34</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">924</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">340</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m11\u001b[0m \u001b[1;92m14:43:34\u001b[0m,\u001b[1;36m924\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m340\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">11</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">14:43:36</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">420</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">630</span><span style=\"font-weight: bold\">}</span> - INFO - Auto LR Find Started                        \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m11\u001b[0m \u001b[1;92m14:43:36\u001b[0m,\u001b[1;36m420\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m630\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started                        \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "faaf11435adc4d29988fe03401b4d9e2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_steps=100` reached.\n",
      "Learning rate set to 0.04365158322401657\n",
      "Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_4764cd76-0064-458b-881a-e2772b684d4d.ckpt\n",
      "Restored all states from the checkpoint at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_4764cd76-0064-458b-881a-e2772b684d4d.ckpt\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">11</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">14:43:40</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">092</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">643</span><span style=\"font-weight: bold\">}</span> - INFO - Suggested LR: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.04365158322401657</span>. For plot \n",
       "and detailed analysis, use `find_learning_rate` method.                                                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m11\u001b[0m \u001b[1;92m14:43:40\u001b[0m,\u001b[1;36m092\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m643\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m0.04365158322401657\u001b[0m. For plot \n",
       "and detailed analysis, use `find_learning_rate` method.                                                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">11</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">14:43:40</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">095</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">652</span><span style=\"font-weight: bold\">}</span> - INFO - Training Started                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m11\u001b[0m \u001b[1;92m14:43:40\u001b[0m,\u001b[1;36m095\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m652\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                      </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ custom_loss      │ CustomLoss                │      0 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _backbone        │ CategoryEmbeddingBackbone │  4.5 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ _embedding_layer │ Embedding1dLayer          │     92 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ head             │ LinearHead                │     34 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                     \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ custom_loss      │ CustomLoss                │      0 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ CategoryEmbeddingBackbone │  4.5 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer          │     92 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ head             │ LinearHead                │     34 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 4.6 K                                                                                            \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 4.6 K                                                                                                \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 4.6 K                                                                                            \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 4.6 K                                                                                                \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "19c7c0b8a81045c3aeff72fe8ac823d4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=20` reached.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">11</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">14:44:11</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">064</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">663</span><span style=\"font-weight: bold\">}</span> - INFO - Training the model completed                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m11\u001b[0m \u001b[1;92m14:44:11\u001b[0m,\u001b[1;36m064\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m663\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">11</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">14:44:11</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">065</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1487</span><span style=\"font-weight: bold\">}</span> - INFO - Loading the best model                     \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m11\u001b[0m \u001b[1;92m14:44:11\u001b[0m,\u001b[1;36m065\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1487\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model                     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f693b5f0b90>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from pytorch_lightning.callbacks import DeviceStatsMonitor\n",
    "from torch_optimizer import Lamb\n",
    "\n",
    "\n",
    "class CustomLoss(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(CustomLoss, self).__init__()\n",
    "\n",
    "    def forward(self, inputs, targets):\n",
    "        loss = torch.mean((inputs - targets) ** 2)\n",
    "        return loss.mean()\n",
    "\n",
    "\n",
    "def custom_metric(y_true, y_pred):\n",
    "    return torch.mean(torch.pow(y_true - y_pred, 3))\n",
    "\n",
    "CustomOptimizer = Lamb\n",
    "## Sample of how a Custom Optimizer would look like\n",
    "# class CustomOptimizer(Optimizer):\n",
    "#     def __init__(\n",
    "#         self,\n",
    "#         params,\n",
    "#         lr: float = 1e-3,\n",
    "#         betas=(0.9, 0.999),\n",
    "#         eps: float = 1e-6,\n",
    "#         weight_decay: float = 0,\n",
    "#         clamp_value: float = 10,\n",
    "#         adam: bool = False,\n",
    "#         debias: bool = False,\n",
    "#     ):\n",
    "#         ## some code here\n",
    "#         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)\n",
    "#         super().__init__(params, defaults)\n",
    "\n",
    "#     def step(self, closure=None):\n",
    "#         ## Some code here\n",
    "#         return loss\n",
    "\n",
    "\n",
    "tabular_model.fit(\n",
    "    train=train,\n",
    "    validation=val,\n",
    "    loss=CustomLoss(),\n",
    "    metrics=[custom_metric],\n",
    "    metrics_prob_inputs=[False],\n",
    "    target_transform=[lambda x: x + 100, lambda x: x - 100],\n",
    "    optimizer=CustomOptimizer,\n",
    "    optimizer_params={\"weight_decay\": 1e-6},\n",
    "    callbacks=[DeviceStatsMonitor()],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see the logs of the training process (because we have set `verbose=True`), and the progress bar which shows the training and validation loss/metric. In addition to it, we can observe that the model summary was printed and some logs about availability and usaged of hardware accelerators like GPUs are printed. This is because we have not set `suppress_lightning_logger=True`. If we set that, we will not see these logs.\n",
    "\n",
    "You can further reduce the warnings from PyTorch Lightning by using the `warinings` module from Python, but it's not recommended because you might miss some important warnings.\n",
    "\n",
    "```python\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "```"
   ]
  },
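  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A narrower filter is usually a safer middle ground than ignoring everything. The sketch below uses only the standard `warnings` module; the message pattern and the two example warnings are made up for illustration and are not actual PyTorch Lightning messages.\n",
    "\n",
    "```python\n",
    "import warnings\n",
    "\n",
    "with warnings.catch_warnings(record=True) as caught:\n",
    "    # Start from a state where every warning is reported\n",
    "    warnings.simplefilter(\"always\")\n",
    "    # Ignore only UserWarnings whose message matches this (hypothetical) pattern\n",
    "    warnings.filterwarnings(\"ignore\", message=\".*does not have many workers.*\", category=UserWarning)\n",
    "    warnings.warn(\"The dataloader does not have many workers\", UserWarning)  # suppressed\n",
    "    warnings.warn(\"Something else went wrong\", UserWarning)  # still reported\n",
    "\n",
    "remaining = [str(w.message) for w in caught]\n",
    "print(remaining)\n",
    "```\n",
    "\n",
    "Because the filter targets one message pattern and one category, unrelated warnings stay visible."
   ]
  },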
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Low-Level API\n",
    "\n",
    "The low-level API is more flexible and allows you to write more complicated logic like cross validation, ensembling, etc. The low-level API is more verbose and requires you to write more code, but it comes with more control to the user.\n",
    "\n",
    "The `fit` method is split into three sub-methods:\n",
    "\n",
    "1. `prepare_dataloader`\n",
    "\n",
    "1. `prepare_model`\n",
    "\n",
    "1. `train`\n",
    "\n",
    "The parameters that we discussed in the High Level API are passed to the respective sub-methods. Before getting into the details of each of these methods, let's re-initialize the `TabularModel` and turn off logs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=\"trainer_config.yml\",\n",
    "    verbose=False, # Turn off the verbose to avoid printing logs from different stages\n",
    "    suppress_lightning_logger=True, # Change Lightning Log Level to WARNING\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "#### 1. `prepare_dataloader`\n",
    "\n",
    "This method is responsible for setting up the `TabularDataModule` and returns the object. You can save this object using `save_dataloader` and load it later using `load_datamodule` to skip the data preparation step. This is useful when you are doing cross validation or ensembling.   \n",
    "\n",
    "So, parameters like `train`, `validation`, `train_sampler`, `target_transform`, `cache_data` etc. are passed to this method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "datamodule = tabular_model.prepare_dataloader(\n",
    "                train=train, validation=val, seed=42, target_transform=[lambda x: x + 100, lambda x: x - 100],\n",
    "            )\n"
   ]
  },
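  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `target_transform` list above is a pair of callables: a forward transform and its inverse. As a rough illustration in plain Python (the sample values are made up, and this only mimics the round trip rather than calling PyTorch Tabular), applying the first callable and then the second should recover the original targets:\n",
    "\n",
    "```python\n",
    "# The same pair of callables that is passed as target_transform above\n",
    "forward, inverse = [lambda x: x + 100, lambda x: x - 100]\n",
    "\n",
    "y = [1.5, -20.0, 300.0]  # made-up target values\n",
    "transformed = [forward(v) for v in y]  # shifted scale, assumed to be what training sees\n",
    "recovered = [inverse(v) for v in transformed]  # mapped back to the original scale\n",
    "\n",
    "print(transformed)\n",
    "print(recovered)\n",
    "```\n",
    "\n",
    "If the second callable were not the exact inverse of the first, values mapped back this way would no longer match the original targets."
   ]
  },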
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "#### 2. `prepare_model`\n",
    "\n",
    "This method is responsible for setting up and initializing the model and takes in the prepared datamodule as an input. It returns the model instance.   \n",
    "\n",
    "This method takes the `datamodule` as an input along with other parameters like `loss`, `metrics`, `metrics_prob_inputs`, `optimizer`, and `optimizer_params`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_optimizer import Lamb\n",
    "model = tabular_model.prepare_model(\n",
    "    datamodule,\n",
    "    loss=CustomLoss(),\n",
    "    metrics=[custom_metric],\n",
    "    metrics_prob_inputs=[False],\n",
    "    optimizer=Lamb,\n",
    "    optimizer_params={\"weight_decay\": 1e-6},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "#### 3. `train`\n",
    "\n",
    "This method is responsible for training the model and takes in the prepared datamodule and model as an input. It returns the trained model instance.   \n",
    "\n",
    "`train` takes the `datamodule` and `model` as an input along with other parameters like `callbacks`, `max_epochs`, `min_epochs`, and so on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9ee4e67d581d467fb5bc0a4531ab379c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                      </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ custom_loss      │ CustomLoss                │      0 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _backbone        │ CategoryEmbeddingBackbone │  4.5 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ _embedding_layer │ Embedding1dLayer          │     92 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ head             │ LinearHead                │     34 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                     \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ custom_loss      │ CustomLoss                │      0 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ CategoryEmbeddingBackbone │  4.5 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer          │     92 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ head             │ LinearHead                │     34 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 4.6 K                                                                                            \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 4.6 K                                                                                                \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 4.6 K                                                                                            \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 4.6 K                                                                                                \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a2b2d1e2ce2d4a74bf5c337953970226",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f6900d7da90>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tabular_model.train(\n",
    "    model,\n",
    "    datamodule,\n",
    "    callbacks=[DeviceStatsMonitor()],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predicting and Evaluating on New Data\n",
    "\n",
    "As we saw in the basic tutorial, we can use the `predict` method of the `TabularModel` to predict on new data. But there are a few more parameters which we can use to customize the prediction process.\n",
    "\n",
    "- `progress_bar` - This parameter lets you turn off or choose the kind of progress bar you need. if `rich`, it will use colorful, `rich` progress bars. If `tqdm`, it will use `tqdm` to show the progress bar. \"None\" or `None` will turn off the progress bar. Defaults to `rich`.\n",
    "\n",
    "- `ret_logits` - This is a boolean flag, if turned on will return the raw model output (logits) instead of the probabilities. Typically useful in classification problems. Defaults to `False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>target_0_prediction</th>\n",
       "      <th>target_1_prediction</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>75721</th>\n",
       "      <td>143.383545</td>\n",
       "      <td>128.671326</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>80184</th>\n",
       "      <td>57.778259</td>\n",
       "      <td>35.192749</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19864</th>\n",
       "      <td>60.018860</td>\n",
       "      <td>82.361267</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>76699</th>\n",
       "      <td>-124.672058</td>\n",
       "      <td>-23.280823</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>92991</th>\n",
       "      <td>33.951477</td>\n",
       "      <td>102.290710</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       target_0_prediction  target_1_prediction\n",
       "75721           143.383545           128.671326\n",
       "80184            57.778259            35.192749\n",
       "19864            60.018860            82.361267\n",
       "76699          -124.672058           -23.280823\n",
       "92991            33.951477           102.290710"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prediction = tabular_model.predict(test, progress_bar=None)\n",
    "prediction.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also saw that we can evaluate the model on new data with existing metrics using `evaluate`. But there are a few more parameters which we can use to customize the evaluation process.\n",
    "\n",
    "- `verbose` - A flag, if `True`, will print out the results as well as return them. Defaults to True\n",
    "- `ckpt_path` - If provided, will load the model from the checkpoint path and evaluate on the data. If not provided, will use the current model and evaluate on the data. If model checkpointing was enabled, we can also use `best` to automatically load the best model. Defaults to None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "08e4adb3aad544aeb76b4a9ecac9d685",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Current Model\n",
    "result = tabular_model.evaluate(test, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f3b390f3705d4d598686feb85e41c163",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">    test_custom_metric     </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     54.31338882446289     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">   test_custom_metric_0    </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    -47.28584289550781     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">   test_custom_metric_1    </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    101.59922790527344     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">      53.098876953125      </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_0        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     24.37938117980957     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_1        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    28.719507217407227     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m   test_custom_metric    \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    54.31338882446289    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m  test_custom_metric_0   \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   -47.28584289550781    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m  test_custom_metric_1   \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   101.59922790527344    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m     53.098876953125     \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_0       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    24.37938117980957    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_1       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   28.719507217407227    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Loading from a stored checkpoint path\n",
    "best_ckpt_path = tabular_model.trainer.checkpoint_callback.best_model_path\n",
    "result = tabular_model.evaluate(test, verbose=True, ckpt_path=best_ckpt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ad836e882eb14a97a209c7b157ceab7f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">    test_custom_metric     </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     54.31338882446289     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">   test_custom_metric_0    </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    -47.28584289550781     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">   test_custom_metric_1    </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    101.59922790527344     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">      53.098876953125      </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_0        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     24.37938117980957     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_1        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    28.719507217407227     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m   test_custom_metric    \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    54.31338882446289    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m  test_custom_metric_0   \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   -47.28584289550781    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m  test_custom_metric_1   \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   101.59922790527344    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m     53.098876953125     \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_0       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    24.37938117980957    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_1       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   28.719507217407227    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Using best checkpoint from training\n",
    "result = tabular_model.evaluate(test, verbose=True, ckpt_path=\"best\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"background-color: #C8E6C9; padding: 10px; color: #1b7678\">\n",
    "<b>Congrats!</b>: You have learned how to use most of the advanced features of PyTorch Tabular. <br></br>\n",
    "\n",
    "\n",
    "Now try to use these features in your own projects and Kaggle competitions. If you have any questions, please feel free to ask them in the <a src=https://github.com/manujosephv/pytorch_tabular/discussions>GitHub Discussions</a>\n",
    "</div>"
   ]
  }
 ],
 "metadata": {
  "accelerator": "auto",
  "colab": {
   "name": "02-Advanced_Usage.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "ad8d5d2789703c7b1c2f7bfaada1cbd3aa0ac53e2e4e1cae5da195f5520da229"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
